{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.6 深度卷积神经网络（AlexNet）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-19T07:36:45.657048Z",
     "start_time": "2019-03-19T07:36:45.285668Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.0\n",
      "0.2.1\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torchvision\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "print(torch.__version__)\n",
    "print(torchvision.__version__)\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.6.2 AlexNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-19T07:36:45.703036Z",
     "start_time": "2019-03-19T07:36:45.658231Z"
    }
   },
   "outputs": [],
   "source": [
    "class AlexNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AlexNet, self).__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            nn.Conv2d(1, 96, 11, 4), # in_channels, out_channels, kernel_size, stride, padding\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(3, 2), # kernel_size, stride\n",
    "            # 减小卷积窗口，使用填充为2来使得输入与输出的高和宽一致，且增大输出通道数\n",
    "            nn.Conv2d(96, 256, 5, 1, 2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(3, 2),\n",
    "            # 连续3个卷积层，且使用更小的卷积窗口。除了最后的卷积层外，进一步增大了输出通道数。\n",
    "            # 前两个卷积层后不使用池化层来减小输入的高和宽\n",
    "            nn.Conv2d(256, 384, 3, 1, 1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(384, 384, 3, 1, 1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(384, 256, 3, 1, 1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(3, 2)\n",
    "        )\n",
    "         # 这里全连接层的输出个数比LeNet中的大数倍。使用丢弃层来缓解过拟合\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(256*5*5, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            # 输出层。由于这里使用Fashion-MNIST，所以用类别数为10，而非论文中的1000\n",
    "            nn.Linear(4096, 10),\n",
    "        )\n",
    "\n",
    "    def forward(self, img):\n",
    "        feature = self.conv(img)\n",
    "        output = self.fc(feature.view(img.shape[0], -1))\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-19T07:36:46.053598Z",
     "start_time": "2019-03-19T07:36:45.704356Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AlexNet(\n",
      "  (conv): Sequential(\n",
      "    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))\n",
      "    (1): ReLU()\n",
      "    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "    (4): ReLU()\n",
      "    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (7): ReLU()\n",
      "    (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (9): ReLU()\n",
      "    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): ReLU()\n",
      "    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (fc): Sequential(\n",
      "    (0): Linear(in_features=6400, out_features=4096, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Dropout(p=0.5)\n",
      "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (4): ReLU()\n",
      "    (5): Dropout(p=0.5)\n",
      "    (6): Linear(in_features=4096, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = AlexNet()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.6.3 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-19T07:36:46.066761Z",
     "start_time": "2019-03-19T07:36:46.054928Z"
    }
   },
   "outputs": [],
   "source": [
    "# 本函数已保存在d2lzh_pytorch包中方便以后使用\n",
    "def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):\n",
    "    \"\"\"Download the fashion mnist dataset and then load into memory.\"\"\"\n",
    "    trans = []\n",
    "    if resize:\n",
    "        trans.append(torchvision.transforms.Resize(size=resize))\n",
    "    trans.append(torchvision.transforms.ToTensor())\n",
    "    \n",
    "    transform = torchvision.transforms.Compose(trans)\n",
    "    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)\n",
    "    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)\n",
    "\n",
    "    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)\n",
    "    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)\n",
    "\n",
    "    return train_iter, test_iter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-19T07:36:46.091524Z",
     "start_time": "2019-03-19T07:36:46.067835Z"
    }
   },
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "# 如出现“out of memory”的报错信息，可减小batch_size或resize\n",
    "train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.6.4 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-19T07:36:47.850402Z",
     "start_time": "2019-03-19T07:36:46.092485Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.0047, train acc 0.770, test acc 0.865, time 128.3 sec\n",
      "epoch 2, loss 0.0025, train acc 0.879, test acc 0.889, time 128.8 sec\n",
      "epoch 3, loss 0.0022, train acc 0.898, test acc 0.901, time 130.4 sec\n",
      "epoch 4, loss 0.0019, train acc 0.908, test acc 0.900, time 131.4 sec\n",
      "epoch 5, loss 0.0018, train acc 0.913, test acc 0.902, time 129.9 sec\n"
     ]
    }
   ],
   "source": [
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
